# Copyright (c) 2023 Copyright holder of the paper "Revisiting Image Classifier Training for Improved Certified Robust Defense against Adversarial Patches" submitted to TMLR for review

# All rights reserved.

import math
import numpy as np
import torch
from torch import nn


def generate_masks(mask_size, stride, mask_set_size, img_size):
    """Build a square grid of rectangular occlusion masks.

    The grid has ``int(sqrt(mask_set_size))`` rows and columns. Mask number
    ``row * per_side + col`` (row-major order) is 1 over the window whose
    top-left corner is ``(row * stride, col * stride)`` and whose side is
    ``mask_size`` (the far edge clipped to ``[0, img_size]``), and 0 elsewhere.

    Args:
        mask_size: side length of each masked window, in pixels.
        stride: offset between neighbouring windows, in pixels.
        mask_set_size: number of mask slots allocated. Only the first
            ``int(sqrt(mask_set_size)) ** 2`` slots are filled, so a
            non-square value leaves the trailing masks all-zero.
        img_size: height and width of the square image.

    Returns:
        torch.Tensor of shape ``(1, mask_set_size, img_size, img_size)``.
    """
    masks = torch.zeros(1, mask_set_size, img_size, img_size)
    per_side = int(np.sqrt(mask_set_size))

    for row in range(per_side):
        top = row * stride
        # Only the far edge is clipped; a start past the image yields an empty slice.
        rows = slice(top, np.clip(top + mask_size, 0, img_size))
        for col in range(per_side):
            left = col * stride
            cols = slice(left, np.clip(left + mask_size, 0, img_size))
            masks[0, row * per_side + col, rows, cols] = 1

    return masks


# code taken from https://raw.githubusercontent.com/inspire-group/PatchCleanser/main/utils/cutout.py
class Cutout(object):
    """Randomly mask out one or more square patches from an image.

    Masked pixels are replaced with the constant value 0.5 rather than zeroed.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): Nominal side length (in pixels) of each square patch.
            Each patch spans ``[c - length // 2, c + length // 2)`` around a
            uniformly sampled centre ``c`` on each axis and is clipped at the
            image border. An odd ``length`` therefore yields a side of
            ``length - 1``, and ``length=1`` masks nothing.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Apply cutout to ``img``.

        Args:
            img (Tensor): Tensor image of size (C, H, W).

        Returns:
            Tensor: A new tensor equal to ``img`` except that ``n_holes``
            patches (each at most ``length x length``, possibly overlapping)
            are set to 0.5 in every channel.
        """
        h = img.size(1)
        w = img.size(2)
        half = self.length // 2

        mask = np.ones((h, w), np.float32)

        for _ in range(self.n_holes):
            # Centres come from numpy's global RNG, so np.random.seed controls them.
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - half, 0, h)
            y2 = np.clip(y + half, 0, h)
            x1 = np.clip(x - half, 0, w)
            x2 = np.clip(x + half, 0, w)

            mask[y1:y2, x1:x2] = 0.

        # torch.from_numpy always yields a CPU tensor; move it to the image's
        # device so that CUDA inputs do not fail with a device mismatch.
        mask = torch.from_numpy(mask).to(img.device)
        mask = mask.expand_as(img)
        return img * mask + (1 - mask) * 0.5
